Feat/conformal prediction #2552
base: master
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #2552 +/- ##
==========================================
- Coverage 94.14% 94.12% -0.02%
==========================================
Files 139 140 +1
Lines 14884 15311 +427
==========================================
+ Hits 14013 14412 +399
- Misses 871 899 +28
☔ View full report in Codecov by Sentry.
Amazing work @dennis!
Some small comments, mostly documentation.
@@ -787,16 +790,19 @@ def merr(
Returns
-------
float
A single metric score for:
A single metric score for (with `len(q) <= 1`):
The wording is a bit confusing here, I would maybe order the words a bit differently
A single metric score for (with `len(q) <= 1`):
A single metric score (when `len(q) <= 1`) for:
- single univariate series.
- single multivariate series with `component_reduction`.
- sequence (list) of uni/multivariate series with `series_reduction` and `component_reduction`.
- a sequence (list) of uni/multivariate series with `series_reduction` and `component_reduction`.
If we add an "a" here, we should add it to the bullet points above as well to stay consistent.
- the input from the `float` return case above but with `len(q) > 1`.
I don't find this bullet point very clear; do you mean that when `len(q) > 1`, the bullet points from above also apply here? I wonder if it would be clearer to just repeat them here as well.
For time series that are overlapping in time without having the same time index, setting `True`
will consider the values only over their common time interval (intersection in time).
q_interval
The quantile interval(s) to compute the metric on. Must be a tuple (single interval) or sequence tuples
The quantile interval(s) to compute the metric on. Must be a tuple (single interval) or sequence tuples
The quantile interval(s) to compute the metric on. Must be a tuple (single interval) or sequence of tuples
for q_high, q_low in zip(
    self.quantiles[self.idx_median + 1 :][::-1],
    self.quantiles[: self.idx_median],
)
It would be a bit simpler to just reuse the list of tuples stored in `self.q_interval` instead of iterating on the array again.
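To make the suggestion concrete, here is a minimal sketch of building the interval pairs once and reusing them afterwards. The helper name `quantile_intervals` is hypothetical, and the assumption that `self.q_interval` holds `(low, high)` tuples ordered from the widest interval to the narrowest is mine, not something confirmed by the diff.

```python
def quantile_intervals(quantiles):
    """Pair each lower quantile with its mirrored upper quantile.

    Assumes `quantiles` contains the median 0.5 and is symmetric around it.
    Returns (low, high) tuples, widest interval first.
    """
    quantiles = sorted(quantiles)
    idx_median = quantiles.index(0.5)
    lows = quantiles[:idx_median]
    # reverse the upper half so the highest quantile lines up with the lowest
    highs = quantiles[idx_median + 1:][::-1]
    return list(zip(lows, highs))


# built once (e.g. at model creation), then reused wherever the pairs are needed
q_interval = quantile_intervals([0.05, 0.2, 0.5, 0.8, 0.95])
for q_low, q_high in q_interval:
    print(q_low, q_high)
```

With the pairs stored this way, later loops only need to iterate over the stored list rather than slicing the quantile array again.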
model.predict(n=1)

pred = model.predict(n=self.horizon, series=self.ts_pass_train, **pred_lklp)
assert pred.n_components == self.ts_pass_train.n_components * 3
assert pred.n_components == self.ts_pass_train.n_components * 3
assert pred.n_components == self.ts_pass_train.n_components * len(kwargs["quantiles"])
    len(pred_list) == 2
), f"Model {model_cls} did not return a list of prediction"
for pred, pred_fc in zip(pred_list, pred_fc_list):
    assert pred.n_components == self.ts_pass_train.n_components * 3
assert pred.n_components == self.ts_pass_train.n_components * 3
assert pred.n_components == self.ts_pass_train.n_components * len(kwargs["quantiles"])
with pytest.raises(ValueError):
    covs = cov_kwargs_train[cov_name]
    covs = {cov_name: covs.stack(covs)}
    _ = model.predict(n=OUT_LEN + 1, **covs, **pred_lklp)
To eliminate the other possible source of error, n should be OUT_LEN
with pytest.raises(ValueError):
    covs = cov_kwargs_notrain[cov_name]
    covs = {cov_name: covs.stack(covs)}
    _ = model.predict(n=OUT_LEN + 1, **covs, **pred_lklp)
`n` should be `OUT_LEN`.
model = ConformalNaiveModel(model=train_model(series), quantiles=quantiles)
# direct quantile predictions
pred_quantiles = model.predict(n=3, series=series, **pred_lklp)
# smapled predictions
# smapled predictions
# sampled predictions
Checklist before merging this PR:
Fixes #1704, fixes #2161.
Short Summary
- `ConformalNaiveModel`, and `ConformalQRModel` (read more below).
- `iws()`, and Mean Interval Winkler Scores `miws()` (time-aggregated) (source)
- `ic()` (binary if observation is within the quantile interval), and Mean Interval Coverage `mic()` (time-aggregated)
- `incs_qr()`, and Mean ... `mincs_qr()` (time-aggregated) (source)
- `overlap_end=True` in `ForecastingModel.residuals()`. This computes historical forecasts and residuals that can extend further than the end of the target series. With this, all returned residual values have the same length per forecast (the last residuals will contain missing values if the forecasts extended further into the future than the end of the target series).

Summary
Adds the first conformal prediction models to Darts. Conformal models can be applied to any of Darts' global forecasting models, as long as the model has been fitted before. In general, the workflow of the models to produce one forecast/prediction is as follows:
The number of calibration examples from the most recent past to use for one conformal prediction can be defined at model creation with parameter `cal_length`. To make your life simpler, we support two modes:
- `series`, `past_covariates`, ...). This is the default mode and our predict/forecasting/backtest/... API is identical to any other forecasting model
- `cal_series`, `cal_past_covariates`, ... .
- `quantiles`).

Input Support
All added conformal models support the following input (depending on the fitted forecasting model):
Forecast/Output Support
All models support the following prediction modes:
- `quantiles=[0.05, 0.2, 0.5, 0.8, 0.95]`).
- `cal_length` (to make the algorithm adaptive)
- `predict_likelihood_parameters=True, num_samples=1` in all prediction methods.
- `num_samples>>1` in all prediction methods.

Requirements to use a conformal model:
- `n`. It must be possible to generate at least `n + cal_length` historical forecasts from the calibration input series.

Added Algorithms
Added two algorithms, each with two symmetry modes:
- `ConformalNaiveModel`: Adds calibrated intervals around the median forecast from the forecasting model.
  - `symmetric=True`: `ae()` (absolute error) to compute the non-conformity scores
  - `symmetric=False`: `err()` (error) to compute the non-conformity scores of the upper bounds, and `-err()` for the lower bounds.
- `ConformalQRModel` (Conformalized Quantile Regression, source): Calibrates the quantile predictions from a probabilistic forecasting model.
  - `symmetric=True`: `incs_qr(symmetric=True)` (Quantile Regression Non-Conformity Score) to compute the non-conformity scores
  - `symmetric=False`: `incs_qr(symmetric=False)` (Quantile Regression Non-Conformity Score) to compute the non-conformity scores for the upper and lower bound separately.
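
The two symmetry modes of the naive approach can be illustrated with a small, library-independent sketch. It is not the Darts implementation: the function names are hypothetical, the error is assumed to be `actual - forecast`, the miscoverage `alpha` is assumed to be split evenly between both sides in the asymmetric case, and the usual finite-sample correction of conformal prediction is left out for brevity.

```python
import math


def empirical_quantile(scores, level):
    """Smallest score such that at least `level` of all scores are <= it."""
    ordered = sorted(scores)
    k = max(1, math.ceil(level * len(ordered)))
    return ordered[min(k, len(ordered)) - 1]


def naive_conformal_interval(cal_actual, cal_pred, new_pred, alpha=0.1, symmetric=True):
    """Return (lower, upper) around `new_pred` with target coverage 1 - alpha.

    `cal_actual` / `cal_pred` are past observations and the matching
    historical forecasts used as the calibration set.
    """
    # assumption: error = actual - forecast (positive when under-forecasting)
    errors = [a - p for a, p in zip(cal_actual, cal_pred)]
    if symmetric:
        # one score for both sides: absolute error at level 1 - alpha
        q = empirical_quantile([abs(e) for e in errors], 1 - alpha)
        return new_pred - q, new_pred + q
    # separate scores: error for the upper bound, negated error for the lower
    q_hi = empirical_quantile(errors, 1 - alpha / 2)
    q_lo = empirical_quantile([-e for e in errors], 1 - alpha / 2)
    return new_pred - q_lo, new_pred + q_hi


# a model that always under-forecasts: the asymmetric interval shifts upwards
actual, forecast = [0, 1, 2, 3, 4], [0, 0, 0, 0, 0]
print(naive_conformal_interval(actual, forecast, 10.0, alpha=0.5, symmetric=True))   # (8.0, 12.0)
print(naive_conformal_interval(actual, forecast, 10.0, alpha=0.5, symmetric=False))  # (11.0, 13.0)
```

The symmetric mode always centers the interval on the forecast, while the asymmetric mode can move both bounds to one side when the calibration errors are biased.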